{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import wandb\n",
    "\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "import torch\n",
    "import scipy\n",
    "import numpy as np\n",
    "import ot\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import tabulate\n",
    "import tqdm\n",
    "\n",
    "import itertools\n",
    "\n",
    "from matplotlib import rc\n",
    "\n",
    "# Typeset all figure text with LaTeX, using a serif Computer Modern Roman font.\n",
    "rc('text', usetex=True)\n",
    "rc('font', family='serif', serif=['Computer Modern Roman'])\n",
    "\n",
    "%matplotlib inline\n",
    "# Default seaborn color palette.\n",
    "palette = sns.color_palette()\n",
    "\n",
    "# wandb API client used by the query cells below.\n",
    "api = wandb.Api()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
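    "**Note:** this notebook uses outdated experiment config names; to run it against current experiments you may need to update the experiment-name regular expressions (the `# Updated regex pattern` comments in the code cells show the newer names)."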
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def query_runs(api, exp_name_regex, step, project='timgaripov/support_alignment_v2'):\n",
    "    \"\"\"Query wandb for the finished runs of an experiment that stopped at `step`.\n",
    "\n",
    "    Args:\n",
    "        api: `wandb.Api` instance (any object with a compatible `runs` method).\n",
    "        exp_name_regex: regular expression matched against `config.experiment_name`.\n",
    "        step: value the run's `step` summary metric must be equal to.\n",
    "        project: wandb project path to search; pass your own project name to\n",
    "            override the default.\n",
    "\n",
    "    Returns:\n",
    "        The result of `api.runs` for this query, requested in the order of\n",
    "        `config.config/training/seed.value`.\n",
    "    \"\"\"\n",
    "    filters = {\n",
    "        '$and': [\n",
    "            {'config.experiment_name': {'$regex': exp_name_regex}},\n",
    "            {'state': 'finished'},\n",
    "            {'summary_metrics.step': {'$eq': step}},\n",
    "        ]\n",
    "    }\n",
    "    return api.runs(\n",
    "        project,\n",
    "        filters=filters,\n",
    "        order='config.config/training/seed.value',\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_1/c_3/s_alpha_15/d_2/dann_zero\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_2/c_3/s_alpha_15/d_2/dann_zero\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_2/c_3/s_alpha_15/d_2/dann_zero\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_3/c_3/s_alpha_15/d_2/dann_zero\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_3/c_3/s_alpha_15/d_2/dann_zero\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_4/c_3/s_alpha_15/d_2/dann_zero\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_4/c_3/s_alpha_15/d_2/dann_zero\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_5/c_3/s_alpha_15/d_2/dann_zero\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_5/c_3/s_alpha_15/d_2/dann_zero\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: list the finished runs that match one experiment pattern.\n",
    "# Updated regex pattern: usps_mnist_3c/lenet_2d/seed_[1-5]/s_alpha_15/dann_zero[^0_]\n",
    "api = wandb.Api()\n",
    "exp_name_regex = '0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/dann_zero[^0_]'\n",
    "runs = query_runs(api, exp_name_regex, 30000)\n",
    "\n",
    "print(len(runs))\n",
    "for matched_run in runs:\n",
    "    print(matched_run.config['experiment_name'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_distances(x, y, num_iter_max=1000000):\n",
    "    \"\"\"Compute normalized distances between two sets of feature vectors.\n",
    "\n",
    "    Both sets are centered on the mean of their union and divided by the average\n",
    "    norm of the centered union before any distance is computed.\n",
    "\n",
    "    Args:\n",
    "        x: 2-D features (one row per sample) exposing `.numpy()`, e.g. a CPU\n",
    "            `torch.Tensor`. It is not modified.\n",
    "        y: second feature set with the same number of columns. It is not modified.\n",
    "        num_iter_max: iteration cap forwarded to `ot.emd2` as `numItermax`. The\n",
    "            default is higher than POT's own default, which was reached before\n",
    "            optimality on some runs (POT then only warns and returns the\n",
    "            current, non-optimal cost).\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(w1, ssd1, h1)`:\n",
    "            w1: optimal transport cost with uniform weights and Euclidean ground cost,\n",
    "            ssd1: average of the two mean nearest-neighbor distances,\n",
    "            h1: largest nearest-neighbor distance in either direction (Hausdorff distance).\n",
    "    \"\"\"\n",
    "    # `Tensor.numpy()` shares memory with the tensor, so everything below is\n",
    "    # computed out of place to leave the caller's tensors untouched.\n",
    "    x_np = x.numpy()\n",
    "    y_np = y.numpy()\n",
    "\n",
    "    xy_combined = np.concatenate((x_np, y_np))\n",
    "    mean = np.mean(xy_combined, axis=0)\n",
    "    avg_norm = np.mean(np.linalg.norm(xy_combined - mean[None, :], axis=1))\n",
    "\n",
    "    x_normed = (x_np - mean[None, :]) / avg_norm\n",
    "    y_normed = (y_np - mean[None, :]) / avg_norm\n",
    "\n",
    "    M = ot.dist(x_normed, y_normed, metric='euclidean')\n",
    "    # Empty weight lists make POT use uniform weights on both sides.\n",
    "    w1 = ot.emd2([], [], M, numItermax=num_iter_max)\n",
    "\n",
    "    # Nearest-neighbor distances x -> y (per row) and y -> x (per column).\n",
    "    nearest_from_x = np.min(M, axis=1)\n",
    "    nearest_from_y = np.min(M, axis=0)\n",
    "    ssd1 = 0.5 * (np.mean(nearest_from_x) + np.mean(nearest_from_y))\n",
    "    h1 = np.maximum(np.max(nearest_from_x), np.max(nearest_from_y))\n",
    "    return w1, ssd1, h1\n",
    "\n",
    "def feature_distances(run, step):\n",
    "    \"\"\"Download the features saved by a run and compute source/target distances.\n",
    "\n",
    "    Args:\n",
    "        run: wandb run object providing `run.file(name).download(...)`.\n",
    "        step: training step; selects the file `features_step_{step:06d}.pkl`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(w1, ssd1, hd1)` returned by `compute_distances` for the\n",
    "        `features_src_tr` and `features_trg_tr` entries of the file.\n",
    "    \"\"\"\n",
    "    data_dir = './data_feature_distance'\n",
    "    fname = f'features_step_{step:06d}.pkl'\n",
    "    fpath = f'{data_dir}/{fname}'\n",
    "\n",
    "    os.makedirs(data_dir, exist_ok=True)\n",
    "    try:\n",
    "        run.file(fname).download(data_dir, replace=True)\n",
    "        # NOTE: torch.load unpickles the file; only load files from runs you trust.\n",
    "        feature_info = torch.load(fpath)\n",
    "        return compute_distances(feature_info['features_src_tr'], feature_info['features_trg_tr'])\n",
    "    finally:\n",
    "        # Remove the downloaded copy even when loading or the computation fails.\n",
    "        if os.path.exists(fpath):\n",
    "            os.remove(fpath)\n",
    "    \n",
    "\n",
    "class ResultsTable(object):\n",
    "    \"\"\"Gather wandb summaries and feature distances for a list of algorithms.\n",
    "\n",
    "    For every algorithm the constructor queries the finished runs whose experiment\n",
    "    name matches `{prefix}/{algorithm_name}[^0_]`, keeps one run per training seed\n",
    "    and stores, per algorithm, the values of each wandb summary and of the\n",
    "    W1 / SSD1 / H1 feature distances. `print_table` prints aggregated results.\n",
    "\n",
    "    The `*_fn` static methods are aggregation functions for `print_table`: each\n",
    "    one turns an array of values into the text of a table cell.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def mean_std_fn(vals):\n",
    "        \"\"\"Format `vals` as `mean (std) [count]` with two decimals.\"\"\"\n",
    "        mean = np.mean(vals)\n",
    "        std = np.std(vals)\n",
    "        return f'{mean:.2f} ({std:.2f}) [{len(vals)}]'\n",
    "\n",
    "    @staticmethod\n",
    "    def median_iqr_fn(vals):\n",
    "        \"\"\"Format `vals` as `median (interquartile range) [count]` with two decimals.\"\"\"\n",
    "        # Import the submodule explicitly: a bare `import scipy` does not load\n",
    "        # `scipy.stats` on every SciPy version.\n",
    "        from scipy import stats\n",
    "        median = np.median(vals)\n",
    "        iqr = stats.iqr(vals)\n",
    "        return f'{median:.2f} ({iqr:.2f}) [{len(vals)}]'\n",
    "\n",
    "    @staticmethod\n",
    "    def quant_fn(vals):\n",
    "        \"\"\"Format `vals` as `q25 median q75 [count]`, zero-padded with two decimals.\"\"\"\n",
    "        quants = np.quantile(vals, np.array([0.25, 0.5, 0.75]))\n",
    "        return f'{quants[0]:05.2f} {quants[1]:05.2f} {quants[2]:05.2f} [{len(vals)}]'\n",
    "\n",
    "    @staticmethod\n",
    "    def quant_latex_fn(vals):\n",
    "        \"\"\"Format `vals` as LaTeX `median_{q25}^{q75}`, zero-padded with one decimal.\"\"\"\n",
    "        quants = np.quantile(vals, np.array([0.25, 0.5, 0.75]))\n",
    "        return f'$ {quants[1]:04.1f}_{{~{quants[0]:04.1f}}}^{{~{quants[2]:04.1f}}} $'\n",
    "\n",
    "    @staticmethod\n",
    "    def median_x_fn(vals):\n",
    "        \"\"\"Format `vals` as `median [count]` with four decimals.\"\"\"\n",
    "        median = np.median(vals)\n",
    "        return f'{median:.4f} [{len(vals)}]'\n",
    "\n",
    "    @staticmethod\n",
    "    def quant_latex_x_fn(vals):\n",
    "        \"\"\"Format `vals` as LaTeX `median_{q25}^{q75}`, zero-padded with two decimals.\"\"\"\n",
    "        quants = np.quantile(vals, np.array([0.25, 0.5, 0.75]))\n",
    "        return f'$ {quants[1]:05.2f}_{{~{quants[0]:05.2f}}}^{{~{quants[2]:05.2f}}} $'\n",
    "\n",
    "    @staticmethod\n",
    "    def quant_or_fn(vals):\n",
    "        \"\"\"Format `vals` as LaTeX `median_{q25}^{q75}` with two decimals and no padding.\"\"\"\n",
    "        quants = np.quantile(vals, np.array([0.25, 0.5, 0.75]))\n",
    "        return f'$ {quants[1]:0.2f}_{{{quants[0]:0.2f}}}^{{{quants[2]:0.2f}}} $'\n",
    "\n",
    "    def __init__(self, prefix, algorithms, num_steps):\n",
    "        \"\"\"Query wandb and compute the feature distances for every algorithm.\n",
    "\n",
    "        Args:\n",
    "            prefix: experiment-name regex prefix shared by all algorithms.\n",
    "            algorithms: list of algorithm names; each one is appended to `prefix`.\n",
    "            num_steps: step the runs must have stopped at; it also selects which\n",
    "                saved feature file is downloaded for each run.\n",
    "        \"\"\"\n",
    "        # `import tqdm` alone does not reliably expose the `tqdm.notebook` submodule.\n",
    "        import tqdm.notebook\n",
    "\n",
    "        api = wandb.Api()\n",
    "        self.algorithms = algorithms\n",
    "        self.prefix = prefix\n",
    "\n",
    "        os.makedirs('./data_feature_distance', exist_ok=True)\n",
    "\n",
    "        self.summaries = [\n",
    "            'eval/target_val/accuracy_class_avg',\n",
    "            'eval/target_val/accuracy_class_min',\n",
    "            'eval/target_val/ce_class_avg',\n",
    "            'alignment_eval_original/ot_sq',\n",
    "            'alignment_eval_original/supp_dist_sq',\n",
    "            'alignment_eval_original/log_loss',\n",
    "        ]\n",
    "        distance_names = ['W1', 'SSD1', 'H1']\n",
    "\n",
    "        # results[name][i] holds the per-run values of `name` for algorithms[i].\n",
    "        self.results = {\n",
    "            name: list() for name in self.summaries + distance_names\n",
    "        }\n",
    "\n",
    "        for algorithm_name in algorithms:\n",
    "            regex = f'{prefix}/{algorithm_name}[^0_]'\n",
    "            print(regex)\n",
    "            qruns = list(query_runs(api, regex, step=num_steps))\n",
    "            print(f'Q runs: {len(qruns)}')\n",
    "\n",
    "            # Keep only the first run returned for each seed.\n",
    "            seeds = set()\n",
    "            runs = []\n",
    "            for run in qruns:\n",
    "                seed = int(run.config['config/training/seed'])\n",
    "                if seed not in seeds:\n",
    "                    seeds.add(seed)\n",
    "                    runs.append(run)\n",
    "            print(f'Runs: {len(runs)}')\n",
    "\n",
    "            for summary_name in self.summaries:\n",
    "                # NOTE: a summary missing from a run is recorded as -1.0.\n",
    "                self.results[summary_name].append([run.summary.get(summary_name, -1.0) for run in runs])\n",
    "\n",
    "            distances_list = [\n",
    "                feature_distances(run, step=num_steps)\n",
    "                for run in tqdm.notebook.tqdm(runs)\n",
    "            ]\n",
    "            # zip(*[]) yields nothing to unpack, so handle the no-runs case explicitly.\n",
    "            if distances_list:\n",
    "                w1_list, ssd1_list, h1_list = zip(*distances_list)\n",
    "            else:\n",
    "                w1_list, ssd1_list, h1_list = (), (), ()\n",
    "            self.results['W1'].append(w1_list)\n",
    "            self.results['SSD1'].append(ssd1_list)\n",
    "            self.results['H1'].append(h1_list)\n",
    "\n",
    "            print()\n",
    "\n",
    "    def print_table(self, summaries, summaries_short, agg_fn, tablefmt='pipe', sep=' '):\n",
    "        \"\"\"Print one row per algorithm with the aggregated value of each summary.\n",
    "\n",
    "        Args:\n",
    "            summaries: keys of `self.results` to show, one per column.\n",
    "            summaries_short: list of column headers, parallel to `summaries`.\n",
    "            agg_fn: function mapping an array of values to the cell text,\n",
    "                e.g. `ResultsTable.quant_fn`.\n",
    "            tablefmt: table format passed to `tabulate.tabulate`.\n",
    "            sep: unused; kept so that existing calls keep working.\n",
    "        \"\"\"\n",
    "        header = ' '.join(summaries)\n",
    "        print(f'{self.prefix}\\n{header}')\n",
    "        table = [['algorithm'] + summaries_short]\n",
    "        for i, algorithm_name in enumerate(self.algorithms):\n",
    "            row = [algorithm_name]\n",
    "            for summary_name in summaries:\n",
    "                values = self.results[summary_name][i]\n",
    "                # Algorithms without any value get an empty cell.\n",
    "                row.append(agg_fn(np.array(values)) if len(values) > 0 else '')\n",
    "            table.append(row)\n",
    "        print(tabulate.tabulate(table, tablefmt=tablefmt, headers='firstrow'))\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 classes, dim 2, no dropout, no relu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/dann_zero[^0_]\n",
      "Q runs: 9\n",
      "Runs: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ba6ac32334204ef5959d6706e7c9e863",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ot/lp/__init__.py:495: UserWarning: numItermax reached before optimality. Try to increase numItermax.\n",
      "  check_result(result_code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/dann[^0_]\n",
      "Q runs: 5\n",
      "Runs: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8ba48b7ab4c0414b8cab41c54536082a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h0[^0_]\n",
      "Q runs: 5\n",
      "Runs: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81edcc6977d94e6a92fcea8125cf4111",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h100[^0_]\n",
      "Q runs: 5\n",
      "Runs: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6d856c545874d778ba229cee04d86f4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h500[^0_]\n",
      "Q runs: 5\n",
      "Runs: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7d04f3c091de4a30973d5d168295811a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs[^0_]\n",
      "Q runs: 5\n",
      "Runs: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4355446b2f1a4dcfbcf0c4a8833ebbda",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h2000[^0_]\n",
      "Q runs: 5\n",
      "Runs: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1289f89de6cd40f2b4c43b6e5bb537ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h5000[^0_]\n",
      "Q runs: 5\n",
      "Runs: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "601883f3ba9442168d35f9e41dc84d33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Algorithm names to compare; each one is appended to the experiment prefix.\n",
    "algorithms = [\n",
    "    'dann_zero',\n",
    "    'dann',\n",
    "    'support_abs_h0',\n",
    "    'support_abs_h100',\n",
    "    'support_abs_h500',\n",
    "    'support_abs',\n",
    "    'support_abs_h2000',\n",
    "    'support_abs_h5000',\n",
    "]\n",
    "\n",
    "# Updated regex pattern: usps_mnist_3c/lenet_2d/seed_[1-5]/s_alpha_15/\n",
    "EXPERIMENT_PREFIX_3C_D2 = '0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2'\n",
    "# Step the runs must have stopped at; also selects the saved feature file.\n",
    "NUM_STEPS = 30000\n",
    "\n",
    "results_3c_d2 = ResultsTable(EXPERIMENT_PREFIX_3C_D2, algorithms, num_steps=NUM_STEPS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2\n",
      "W1 SSD1\n",
      "| algorithm         | W                     | SSD                   |\n",
      "|:------------------|:----------------------|:----------------------|\n",
      "| dann_zero         | 00.76 00.77 00.84 [5] | 00.10 00.10 00.10 [5] |\n",
      "| dann              | 00.06 00.07 00.08 [5] | 00.02 00.02 00.02 [5] |\n",
      "| support_abs_h0    | 00.22 00.23 00.47 [5] | 00.03 00.03 00.03 [5] |\n",
      "| support_abs_h100  | 00.36 00.55 00.56 [5] | 00.03 00.03 00.03 [5] |\n",
      "| support_abs_h500  | 00.55 00.58 00.64 [5] | 00.03 00.03 00.03 [5] |\n",
      "| support_abs       | 00.56 00.59 00.62 [5] | 00.03 00.03 00.03 [5] |\n",
      "| support_abs_h2000 | 00.58 00.62 00.66 [5] | 00.03 00.03 00.03 [5] |\n",
      "| support_abs_h5000 | 00.63 00.64 00.67 [5] | 00.04 00.04 00.04 [5] |\n"
     ]
    }
   ],
   "source": [
    "# Plain-text table: quartiles of the feature distances for every algorithm.\n",
    "columns = [('W1', 'W'), ('SSD1', 'SSD')]\n",
    "summaries = [key for key, _ in columns]\n",
    "summaries_short = [label for _, label in columns]\n",
    "results_3c_d2.print_table(summaries, summaries_short, ResultsTable.quant_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2\n",
      "W1 SSD1\n",
      "| algorithm         | W                      | SSD                    |\n",
      "|:------------------|:-----------------------|:-----------------------|\n",
      "| dann_zero         | $ 0.77_{0.76}^{0.84} $ | $ 0.10_{0.10}^{0.10} $ |\n",
      "| dann              | $ 0.07_{0.06}^{0.08} $ | $ 0.02_{0.02}^{0.02} $ |\n",
      "| support_abs_h0    | $ 0.23_{0.22}^{0.47} $ | $ 0.03_{0.03}^{0.03} $ |\n",
      "| support_abs_h100  | $ 0.55_{0.36}^{0.56} $ | $ 0.03_{0.03}^{0.03} $ |\n",
      "| support_abs_h500  | $ 0.58_{0.55}^{0.64} $ | $ 0.03_{0.03}^{0.03} $ |\n",
      "| support_abs       | $ 0.59_{0.56}^{0.62} $ | $ 0.03_{0.03}^{0.03} $ |\n",
      "| support_abs_h2000 | $ 0.62_{0.58}^{0.66} $ | $ 0.03_{0.03}^{0.03} $ |\n",
      "| support_abs_h5000 | $ 0.64_{0.63}^{0.67} $ | $ 0.04_{0.04}^{0.04} $ |\n"
     ]
    }
   ],
   "source": [
    "# Same table with cells formatted as LaTeX math (median with quartile sub/superscripts).\n",
    "columns = [('W1', 'W'), ('SSD1', 'SSD')]\n",
    "summaries = [key for key, _ in columns]\n",
    "summaries_short = [label for _, label in columns]\n",
    "results_3c_d2.print_table(summaries, summaries_short, ResultsTable.quant_or_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| algorithm         | W                      | SSD                    |\n",
    "|:------------------|:-----------------------|:-----------------------|\n",
    "| dann_zero         | $ 0.77_{0.76}^{0.84} $ | $ 0.10_{0.10}^{0.10} $ |\n",
    "| dann              | $ 0.07_{0.06}^{0.08} $ | $ 0.02_{0.02}^{0.02} $ |\n",
    "| support_abs_h0    | $ 0.23_{0.22}^{0.47} $ | $ 0.03_{0.03}^{0.03} $ |\n",
    "| support_abs_h100  | $ 0.55_{0.36}^{0.56} $ | $ 0.03_{0.03}^{0.03} $ |\n",
    "| support_abs_h500  | $ 0.58_{0.55}^{0.64} $ | $ 0.03_{0.03}^{0.03} $ |\n",
    "| support_abs       | $ 0.59_{0.56}^{0.62} $ | $ 0.03_{0.03}^{0.03} $ |\n",
    "| support_abs_h2000 | $ 0.62_{0.58}^{0.66} $ | $ 0.03_{0.03}^{0.03} $ |\n",
    "| support_abs_h5000 | $ 0.64_{0.63}^{0.67} $ | $ 0.04_{0.04}^{0.04} $ |\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2\n",
      "eval/target_val/accuracy_class_avg eval/target_val/accuracy_class_min W1 SSD1\n",
      "\\begin{tabular}{lllll}\n",
      "\\hline\n",
      " algorithm         & acc (avg)                & acc (min)                & W                        & SSD                      \\\\\n",
      "\\hline\n",
      " dann_zero         & $ 63.0_{~62.3}^{~69.6} $ & $ 45.3_{~37.9}^{~53.6} $ & $ 00.8_{~00.8}^{~00.8} $ & $ 00.1_{~00.1}^{~00.1} $ \\\\\n",
      " dann              & $ 75.6_{~72.4}^{~83.7} $ & $ 54.8_{~49.6}^{~55.1} $ & $ 00.1_{~00.1}^{~00.1} $ & $ 00.0_{~00.0}^{~00.0} $ \\\\\n",
      " support_abs_h0    & $ 73.9_{~73.4}^{~84.1} $ & $ 61.8_{~54.6}^{~72.4} $ & $ 00.2_{~00.2}^{~00.5} $ & $ 00.0_{~00.0}^{~00.0} $ \\\\\n",
      " support_abs_h100  & $ 88.5_{~86.8}^{~95.1} $ & $ 71.4_{~70.6}^{~93.3} $ & $ 00.5_{~00.4}^{~00.6} $ & $ 00.0_{~00.0}^{~00.0} $ \\\\\n",
      " support_abs_h500  & $ 94.5_{~88.7}^{~94.7} $ & $ 89.0_{~83.1}^{~90.3} $ & $ 00.6_{~00.6}^{~00.6} $ & $ 00.0_{~00.0}^{~00.0} $ \\\\\n",
      " support_abs       & $ 91.1_{~91.1}^{~93.0} $ & $ 85.6_{~80.7}^{~86.2} $ & $ 00.6_{~00.6}^{~00.6} $ & $ 00.0_{~00.0}^{~00.0} $ \\\\\n",
      " support_abs_h2000 & $ 94.0_{~91.2}^{~94.7} $ & $ 88.6_{~80.2}^{~89.4} $ & $ 00.6_{~00.6}^{~00.7} $ & $ 00.0_{~00.0}^{~00.0} $ \\\\\n",
      " support_abs_h5000 & $ 82.1_{~81.8}^{~83.9} $ & $ 68.9_{~65.5}^{~70.9} $ & $ 00.6_{~00.6}^{~00.7} $ & $ 00.0_{~00.0}^{~00.0} $ \\\\\n",
      "\\hline\n",
      "\\end{tabular}\n"
     ]
    }
   ],
   "source": [
    "# LaTeX table: target accuracy columns next to the feature-distance columns.\n",
    "columns = [\n",
    "    ('eval/target_val/accuracy_class_avg', 'acc (avg)'),\n",
    "    ('eval/target_val/accuracy_class_min', 'acc (min)'),\n",
    "    ('W1', 'W'),\n",
    "    ('SSD1', 'SSD'),\n",
    "]\n",
    "summaries = [key for key, _ in columns]\n",
    "summaries_short = [label for _, label in columns]\n",
    "results_3c_d2.print_table(\n",
    "    summaries,\n",
    "    summaries_short,\n",
    "    ResultsTable.quant_latex_fn,\n",
    "    tablefmt='latex_raw',\n",
    "    sep=' & ',\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
